import numpy as np
import matplotlib.pyplot as plt


def calcular_aceleracao_drag_gravidade(v, g, vT):
    """
    Calcula a aceleração total devido à resistência do ar (drag) e gravidade.

    Parâmetros:
        v (np.ndarray): vetor velocidade (3D) [vx, vy, vz].
        g (float): aceleração da gravidade (m/s²).
        vT (float): velocidade terminal (m/s).

    Retorna:
        np.ndarray: vetor aceleração (3D).
    """
    D = g / vT**2  # coeficiente de arrasto
    norm_v = np.linalg.norm(v)
    a_drag = -D * norm_v * v
    a_grav = np.array([0, 0, -g])
    return a_drag + a_grav


def simular_trajetoria_euler_3d(pos_inicial, vel_inicial, g, vT, dt, tf):
    """
    Simula a trajetória 3D de uma bola de ténis usando o método de Euler.

    Parâmetros:
        pos_inicial (tuple): posição inicial (x0, y0, z0) em metros.
        vel_inicial (tuple): velocidade inicial (vx0, vy0, vz0) em m/s.
        g (float): aceleração da gravidade (m/s²).
        vT (float): velocidade terminal (m/s).
        dt (float): passo de tempo (s).
        tf (float): tempo final (s).

    Retorna:
        tuple: (tempo, posições, velocidades)
            - tempo (np.ndarray)
            - posicoes (np.ndarray shape (n, 3))
            - velocidades (np.ndarray shape (n, 3))
    """
    n = int(tf / dt)
    t = np.zeros(n + 1)
    r = np.zeros((n + 1, 3))
    v = np.zeros((n + 1, 3))

    r[0] = np.array(pos_inicial)
    v[0] = np.array(vel_inicial)

    for i in range(n):
        t[i + 1] = t[i] + dt

        a = calcular_aceleracao_drag_gravidade(v[i], g, vT)

        # Atualizar velocidade e posição
        v[i + 1] = v[i] + a * dt
        r[i + 1] = r[i] + v[i] * dt

        # Verifica se a bola atingiu o solo (z <= 0)
        if r[i + 1, 2] <= 0:
            print(f"A bola atinge o solo em t = {t[i + 1]:.2f} s")
            return t[:i + 2], r[:i + 2], v[:i + 2]

    return t, r, v


def simular_trajetoria_tenis_3d(pos_inicial, vel_inicial, v_terminal, dt, tf):
    """
    Simula a trajetória de uma bola de tênis usando o método de Euler em 3D.
    Considera a força da gravidade e a resistência do ar proporcional à velocidade.

    Parâmetros:
    - pos_inicial: tupla (x0, y0, z0) posição inicial em metros
    - vel_inicial: tupla (vx0, vy0, vz0) velocidade inicial em m/s
    - v_terminal: velocidade terminal em m/s (padrão: 120 km/h)
    - dt: passo de tempo em segundos (padrão: 0.01)
    - tf: tempo final da simulação em segundos (padrão: 2.0)

    Retorna:
    - dicionário com arrays de tempo, posição (x, y, z) e velocidade (vx, vy, vz)
    """
    g = 9.8
    D = g / v_terminal**2
    n = int(tf / dt)

    x0, y0, z0 = pos_inicial
    vx0, vy0, vz0 = vel_inicial

    t = np.zeros(n+1)
    x = np.zeros(n+1); y = np.zeros(n+1); z = np.zeros(n+1)
    vx = np.zeros(n+1); vy = np.zeros(n+1); vz = np.zeros(n+1)
    ax = np.zeros(n+1); ay = np.zeros(n+1); az = np.zeros(n+1)

    x[0], y[0], z[0] = x0, y0, z0
    vx[0], vy[0], vz[0] = vx0, vy0, vz0

    for i in range(n):
        t[i+1] = t[i] + dt
        vv = np.sqrt(vx[i]**2 + vy[i]**2 + vz[i]**2)

        ax[i] = -D * vv * vx[i]
        ay[i] = -D * vv * vy[i]
        az[i] = -g - D * vv * vz[i]

        vx[i+1] = vx[i] + ax[i] * dt
        vy[i+1] = vy[i] + ay[i] * dt
        vz[i+1] = vz[i] + az[i] * dt

        x[i+1] = x[i] + vx[i] * dt
        y[i+1] = y[i] + vy[i] * dt
        z[i+1] = z[i] + vz[i] * dt

        if z[i+1] <= 0:
            print(f"A bola atinge o solo em t = {t[i+1]:.2f} s")
            x = x[:i+2]; y = y[:i+2]; z = z[:i+2]; t = t[:i+2]
            vx = vx[:i+2]; vy = vy[:i+2]; vz = vz[:i+2]
            break

    #return {
    #    't': t, 'x': x, 'y': y, 'z': z,
    #    'vx': vx, 'vy': vy, 'vz': vz,
    #    'ax': ax[:len(t)], 'ay': ay[:len(t)], 'az': az[:len(t)]
    #}
    
    return x, y, z, vx, vy, vz


def plotar_trajetoria_3d(x,y,z):
    """
    Plota a trajetória 3D da bola.

    Parâmetros:
        r (np.ndarray): posições da bola ao longo do tempo (n, 3).
    """
    fig = plt.figure(figsize=(10, 6))
    ax = plt.axes(projection='3d')
    ax.plot(x,y,z, label='Trajetória 3D')
    ax.set_xlabel('x (m)')
    ax.set_ylabel('y (m)')
    ax.set_zlabel('z (m)')
    ax.set_title('Simulação da trajetória (Euler 3D)')
    ax.legend()
    plt.show()


# ======================
# EXEMPLO DE USO
# ======================

if __name__ == "__main__":
    # Parâmetros iniciais
    pos_inicial = (0, 2, 3)  # posição em metros
    vel_inicial = (160000 / 3600, 20000 / 3600, -20000 / 3600)  # m/s
    g = 9.8
    vT = 120000 / 3600  # velocidade terminal em m/s
    dt = 0.01
    tf = 2.0
    m = 0.057   #massa da bola (kg)


    #t, r, v = simular_trajetoria_euler_3d(pos_inicial, vel_inicial, g, vT, dt, tf)
    x, y, z, vx, vy, vz = simular_trajetoria_tenis_3d(pos_inicial, vel_inicial, g, vT, dt, tf)
    plotar_trajetoria_3d(x,y,z)


    #energia mecanica:
    # Energia cinética: (1/2) m (vx² + vy² + vz²)
    # Energia potencial: m g z
    E = 0.5 * m * (vx**2 + vy**2 + vz**2) + m * g * z

    print("Energia mecânica inicial: {:.3f} J".format(E[0]))
    print("Energia mecânica no impacto: {:.3f} J".format(E[-1]))